{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Stable UnCLIP\n",
    "\n",
    "UnCLIPPipeline [kakaobrain/karlo-v1-alpha](https://huggingface.co/kakaobrain/karlo-v1-alpha) provides a prior model that can generate a CLIP image embedding from text.\n",
    "\n",
    "StableDiffusionImageVariationPipeline [lambdalabs/sd-image-variations-diffusers](https://huggingface.co/lambdalabs/sd-image-variations-diffusers) provides a decoder model that can generate images from a CLIP image embedding.\n",
    "\n",
    "This script was contributed by [Ray Wang](https://wrong.wang/) and the notebook by [Parag Ekbote](https://github.com/ParagEkbote)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: diffusers in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (0.31.0)\n",
      "Requirement already satisfied: torch in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (2.2.1+cu121)\n",
      "Requirement already satisfied: importlib-metadata in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (8.5.0)\n",
      "Requirement already satisfied: filelock in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (3.16.1)\n",
      "Requirement already satisfied: huggingface-hub>=0.23.2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (0.26.3)\n",
      "Requirement already satisfied: numpy in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (1.26.4)\n",
      "Requirement already satisfied: regex!=2019.12.17 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (2024.11.6)\n",
      "Requirement already satisfied: requests in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (2.32.3)\n",
      "Requirement already satisfied: safetensors>=0.3.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (0.4.5)\n",
      "Requirement already satisfied: Pillow in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (11.0.0)\n",
      "Requirement already satisfied: typing-extensions>=4.8.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (4.12.2)\n",
      "Requirement already satisfied: sympy in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (1.13.3)\n",
      "Requirement already satisfied: networkx in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (3.4.2)\n",
      "Requirement already satisfied: jinja2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (3.1.4)\n",
      "Requirement already satisfied: fsspec in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (2024.9.0)\n",
      "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (8.9.2.26)\n",
      "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (12.1.3.1)\n",
      "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (11.0.2.54)\n",
      "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (10.3.2.106)\n",
      "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (11.4.5.107)\n",
      "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (12.1.0.106)\n",
      "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (2.19.3)\n",
      "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (12.1.105)\n",
      "Requirement already satisfied: triton==2.2.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (2.2.0)\n",
      "Requirement already satisfied: nvidia-nvjitlink-cu12 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch) (12.6.77)\n",
      "Requirement already satisfied: packaging>=20.9 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from huggingface-hub>=0.23.2->diffusers) (24.1)\n",
      "Requirement already satisfied: pyyaml>=5.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from huggingface-hub>=0.23.2->diffusers) (6.0.2)\n",
      "Requirement already satisfied: tqdm>=4.42.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from huggingface-hub>=0.23.2->diffusers) (4.66.6)\n",
      "Requirement already satisfied: zipp>=3.20 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from importlib-metadata->diffusers) (3.21.0)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from jinja2->torch) (3.0.2)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->diffusers) (3.4.0)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->diffusers) (3.10)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->diffusers) (2.2.3)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->diffusers) (2024.8.30)\n",
      "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from sympy->torch) (1.3.0)\n",
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "# Use the %pip magic so packages are installed into the running kernel's environment.\n",
    "# `transformers` is listed explicitly: the pipelines load CLIP text/tokenizer components from it.\n",
    "%pip install diffusers transformers torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "821e3f3a746a45cd9a8cd2ef089bbb42",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Fetching 20 files:   0%|          | 0/20 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6c634cc71f354575893ce4b03b7f4b42",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "text_proj/config.json:   0%|          | 0.00/203 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4574d8a50b634617839b661ebcfd83a8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading pipeline components...:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c165d4e19a1543e593422ec37f73456c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model_index.json:   0%|          | 0.00/545 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "521d58177b1b4db1945413a6b2c5c480",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Fetching 9 files:   0%|          | 0/9 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7c59ab28cb6f48efa68813a295158795",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "pytorch_model.bin:   0%|          | 0.00/1.22G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d344c435695d44e180e2ff71569c75ea",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "scheduler/scheduler_config.json:   0%|          | 0.00/308 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "df520ef176994c228c61d843994d22f8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "diffusion_pytorch_model.bin:   0%|          | 0.00/335M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a9e0bb67ca40475cbeef74d9993b02b7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "(…)ature_extractor/preprocessor_config.json:   0%|          | 0.00/518 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "12a25e3264ee4dc7a7074312827da3bf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "safety_checker/config.json:   0%|          | 0.00/5.01k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "41530a4811f94c37bdf7fd229a93a321",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "unet/config.json:   0%|          | 0.00/871 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f8952d9d237347fb83cb43f331ba9257",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "vae/config.json:   0%|          | 0.00/595 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1b236790e2504042914219f830664eea",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "diffusion_pytorch_model.bin:   0%|          | 0.00/3.44G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "01104b340c074a77b0064b6812e93991",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "An error occurred while trying to fetch /home/zeus/.cache/huggingface/hub/models--lambdalabs--sd-image-variations-diffusers/snapshots/42bc0ee1726b141d49f519a6ea02ccfbf073db2e/unet: Error no file named diffusion_pytorch_model.safetensors found in directory /home/zeus/.cache/huggingface/hub/models--lambdalabs--sd-image-variations-diffusers/snapshots/42bc0ee1726b141d49f519a6ea02ccfbf073db2e/unet.\n",
      "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.\n",
      "An error occurred while trying to fetch /home/zeus/.cache/huggingface/hub/models--lambdalabs--sd-image-variations-diffusers/snapshots/42bc0ee1726b141d49f519a6ea02ccfbf073db2e/vae: Error no file named diffusion_pytorch_model.safetensors found in directory /home/zeus/.cache/huggingface/hub/models--lambdalabs--sd-image-variations-diffusers/snapshots/42bc0ee1726b141d49f519a6ea02ccfbf073db2e/vae.\n",
      "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f0b48094cf6347cbbb0685102579f40d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5dee67c8301d4bf4a34e0f161560affb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "from diffusers import DiffusionPipeline\n",
    "\n",
    "# Prefer the GPU when one is present; otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "# Half precision saves memory on CUDA but is poorly supported on CPU,\n",
    "# so only request float16 when the pipeline will actually run on a GPU.\n",
    "dtype = torch.float16 if device.type == \"cuda\" else torch.float32\n",
    "\n",
    "# `stable_unclip` is a community pipeline: the Karlo prior turns the prompt into\n",
    "# a CLIP image embedding, which the Stable Diffusion image-variation decoder renders.\n",
    "pipeline = DiffusionPipeline.from_pretrained(\n",
    "    \"kakaobrain/karlo-v1-alpha\",\n",
    "    torch_dtype=dtype,\n",
    "    custom_pipeline=\"stable_unclip\",\n",
    "    decoder_pipe_kwargs=dict(\n",
    "        # The embedding comes from the prior, so the decoder's own image encoder is not loaded.\n",
    "        image_encoder=None,\n",
    "    ),\n",
    ")\n",
    "pipeline.to(device)\n",
    "\n",
    "prompt = \"a shiba inu wearing a beret and black turtleneck\"\n",
    "# Fixed seed so repeated runs on the same device produce the same image.\n",
    "random_generator = torch.Generator(device=device).manual_seed(1000)\n",
    "output = pipeline(\n",
    "    prompt=prompt,\n",
    "    width=512,\n",
    "    height=512,\n",
    "    generator=random_generator,\n",
    "    prior_guidance_scale=4,\n",
    "    prior_num_inference_steps=25,\n",
    "    decoder_guidance_scale=8,\n",
    "    decoder_num_inference_steps=50,\n",
    ")\n",
    "\n",
    "image = output.images[0]\n",
    "image.save(\"./shiba-inu.jpg\")\n",
    "# Leave the image as the cell's last expression so the notebook renders it inline.\n",
    "image"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
